{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the whole dataset into memory (presumably BPE token ids, going by the\n",
    "# file name - TODO confirm against the preprocessing notebook).\n",
    "with open('dataset-bpe.json') as dataset_file:\n",
    "    data = json.load(dataset_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the four splits out of the loaded dictionary.\n",
    "train_X, train_Y, test_X, test_Y = (\n",
    "    data[split] for split in ('train_X', 'train_Y', 'test_X', 'test_Y')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "EOS = 2\n",
    "GO = 1\n",
    "vocab_size = 32000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Terminate every target sequence with EOS (named constant instead of a bare 2).\n",
    "# Sequences that already end in EOS are left alone, so re-running this cell\n",
    "# does not stack additional EOS tokens onto the targets.\n",
    "train_Y = [seq if seq[-1:] == [EOS] else seq + [EOS] for seq in train_Y]\n",
    "test_Y = [seq if seq[-1:] == [EOS] else seq + [EOS] for seq in test_Y]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensor2tensor.utils import beam_search\n",
    "\n",
    "def pad_second_dim(x, desired_size):\n",
    "    \"\"\"Right-pad a rank-3 float tensor with zeros along axis 1 up to `desired_size`.\"\"\"\n",
    "    dims = tf.shape(x)\n",
    "    # Shape of the zero block that is appended after the existing axis-1 entries.\n",
    "    pad_shape = tf.stack([dims[0], desired_size - dims[1], dims[2]], 0)\n",
    "    zeros = tf.tile([[[0.0]]], pad_shape)\n",
    "    return tf.concat([x, zeros], 1)\n",
    "\n",
    "def hop_forward(memory_o, memory_i, response_proj, inputs_len, questions_len):\n",
    "    \"\"\"Run one memory hop and return the projected response.\n",
    "\n",
    "    `memory_i` is masked, soft-maxed, masked again, multiplied element-wise\n",
    "    with `memory_o`, and finally passed through `response_proj`.\n",
    "\n",
    "    NOTE(review): `memory_i` is used directly as the score tensor (there is no\n",
    "    query/key product) - confirm this is the intended formulation.\n",
    "    \"\"\"\n",
    "    scores = pre_softmax_masking(memory_i, inputs_len)\n",
    "    weights = post_softmax_masking(tf.nn.softmax(scores), questions_len)\n",
    "    return response_proj(tf.multiply(weights, memory_o))\n",
    "\n",
    "\n",
    "def pre_softmax_masking(x, seq_len):\n",
    "    \"\"\"Replace entries of rank-3 `x` whose last-axis index is >= `seq_len` with -inf.\n",
    "\n",
    "    The per-example mask is built over the last axis and repeated for every\n",
    "    position of axis 1, so a following softmax gives those entries zero weight.\n",
    "    \"\"\"\n",
    "    x_shape = tf.shape(x)\n",
    "    n_rows, n_cols = x_shape[1], x_shape[2]\n",
    "    keep = tf.sequence_mask(seq_len, n_cols, dtype = tf.float32)\n",
    "    keep = tf.tile(tf.expand_dims(keep, 1), [1, n_rows, 1])\n",
    "    neg_inf = tf.fill(x_shape, float('-inf'))\n",
    "    return tf.where(tf.equal(keep, 0), neg_inf, x)\n",
    "\n",
    "\n",
    "def post_softmax_masking(x, seq_len):\n",
    "    \"\"\"Zero every axis-1 position of rank-3 `x` whose index is >= `seq_len`.\"\"\"\n",
    "    x_shape = tf.shape(x)\n",
    "    n_steps, n_cols = x_shape[1], x_shape[2]\n",
    "    keep = tf.sequence_mask(seq_len, n_steps, dtype = tf.float32)\n",
    "    # Repeat the per-step flag across the last axis so it lines up with `x`.\n",
    "    keep = tf.tile(tf.expand_dims(keep, -1), [1, 1, n_cols])\n",
    "    return x * keep\n",
    "\n",
    "def embed_seq(x, vocab_size, zero_pad = True, embed_dim = None):\n",
    "    \"\"\"Look up embeddings for the integer ids in `x`.\n",
    "\n",
    "    Creates (or reuses, under the active variable scope) a `lookup_table`\n",
    "    variable of shape [vocab_size, embed_dim].\n",
    "\n",
    "    Args:\n",
    "        x: integer id tensor.\n",
    "        vocab_size: number of rows in the embedding table.\n",
    "        zero_pad: when True, row 0 of the table is replaced by zeros so that\n",
    "            id 0 always embeds to the zero vector.\n",
    "        embed_dim: embedding width. When omitted, the notebook-level global\n",
    "            `size_layer` is used, which is what this function always did\n",
    "            before the parameter existed.\n",
    "\n",
    "    Returns:\n",
    "        Float tensor of embeddings with one trailing axis of size `embed_dim`.\n",
    "    \"\"\"\n",
    "    if embed_dim is None:\n",
    "        # Legacy behaviour: fall back to the global defined in the config cell.\n",
    "        embed_dim = size_layer\n",
    "    lookup_table = tf.get_variable(\n",
    "        'lookup_table', [vocab_size, embed_dim], tf.float32\n",
    "    )\n",
    "    if zero_pad:\n",
    "        lookup_table = tf.concat(\n",
    "            (tf.zeros([1, embed_dim]), lookup_table[1:, :]), axis = 0\n",
    "        )\n",
    "    return tf.nn.embedding_lookup(lookup_table, x)\n",
    "\n",
    "def sinusoidal_position_encoding(inputs, mask, repr_dim):\n",
    "    \"\"\"Build sinusoidal position features for `inputs`, zeroed where `mask` is 0.\n",
    "\n",
    "    The sine half and the cosine half are concatenated (not interleaved) along\n",
    "    the last axis, so the result has `repr_dim` features when `repr_dim` is even.\n",
    "\n",
    "    Args:\n",
    "        inputs: tensor whose axis 0 is the batch and axis 1 the positions.\n",
    "        mask: tensor over the first two axes of `inputs`; it is cast to float and\n",
    "            multiplied in, so zero entries blank out the encoding.\n",
    "        repr_dim: width of the encoding.\n",
    "    \"\"\"\n",
    "    n_steps = tf.shape(inputs)[1]\n",
    "    pos = tf.reshape(tf.cast(tf.range(n_steps), tf.float32), [-1, 1])\n",
    "    i = np.arange(0, repr_dim, 2, np.float32)\n",
    "    denom = np.reshape(np.power(10000.0, i / repr_dim), [1, -1])\n",
    "    enc = tf.expand_dims(tf.concat([tf.sin(pos / denom), tf.cos(pos / denom)], 1), 0)\n",
    "    return tf.tile(enc, [tf.shape(inputs)[0], 1, 1]) * tf.expand_dims(tf.cast(mask, tf.float32), -1)\n",
    "\n",
    "def quest_mem(x, vocab_size, size_layer):\n",
    "    \"\"\"Embed the ids in `x` and add masked sinusoidal position encodings.\n",
    "\n",
    "    `size_layer` is now forwarded to `embed_seq`, so the embedding width comes\n",
    "    from the argument rather than from a same-named notebook global.\n",
    "    \"\"\"\n",
    "    en_masks = tf.sign(x)  # 1 for non-zero ids, 0 for id 0\n",
    "    x = embed_seq(x, vocab_size, embed_dim = size_layer)\n",
    "    x += sinusoidal_position_encoding(x, en_masks, size_layer)\n",
    "    return x\n",
    "\n",
    "class Translator:\n",
    "    \"\"\"Memory-hop encoder feeding a multi-layer LSTM decoder (TF1 graph mode).\n",
    "\n",
    "    Constructing an instance adds all ops to the default graph and exposes:\n",
    "        X, Y             int32 placeholders for source / target id batches.\n",
    "        X_seq_len, Y_seq_len  per-row counts of non-zero ids.\n",
    "        training_logits  teacher-forced decoder logits.\n",
    "        logits           sample ids of the teacher-forced decoder.\n",
    "        fast_result      greedy-decoded ids.\n",
    "        cost, optimizer  sequence loss and the Adam update op.\n",
    "        prediction, accuracy  masked arg-max predictions and token accuracy.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, size_layer, num_layers, embedded_size, learning_rate,\n",
    "                beam_width = 5, n_hops = 3):\n",
    "        \"\"\"Build the graph.\n",
    "\n",
    "        Args:\n",
    "            size_layer: embedding width and LSTM hidden size.\n",
    "            num_layers: number of stacked LSTM cells in the decoder.\n",
    "            embedded_size: accepted for interface compatibility; not used here.\n",
    "            learning_rate: Adam learning rate.\n",
    "            beam_width: accepted for interface compatibility; decoding is greedy.\n",
    "            n_hops: number of memory hops (must be >= 1).\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if `n_hops` is smaller than 1.\n",
    "        \"\"\"\n",
    "        if n_hops < 1:\n",
    "            raise ValueError('n_hops must be >= 1, got %r' % (n_hops,))\n",
    "\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "\n",
    "        # Lengths are the number of non-zero ids per row (batches are zero-padded).\n",
    "        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype = tf.int32)\n",
    "        self.Y_seq_len = tf.count_nonzero(self.Y, 1, dtype = tf.int32)\n",
    "        batch_size = tf.shape(self.X)[0]\n",
    "\n",
    "        # Embedding table used for the decoder inputs.\n",
    "        lookup_table = tf.get_variable('lookup_table', [vocab_size, size_layer], tf.float32)\n",
    "\n",
    "        with tf.variable_scope('memory_o'):\n",
    "            memory_o = quest_mem(self.X, vocab_size, size_layer)\n",
    "\n",
    "        with tf.variable_scope('memory_i'):\n",
    "            memory_i = quest_mem(self.X, vocab_size, size_layer)\n",
    "\n",
    "        with tf.variable_scope('interaction'):\n",
    "            response_proj = tf.layers.Dense(size_layer)\n",
    "            for _ in range(n_hops):\n",
    "                answer = hop_forward(memory_o,\n",
    "                                     memory_i,\n",
    "                                     response_proj,\n",
    "                                     self.X_seq_len,\n",
    "                                     self.X_seq_len)\n",
    "                memory_i = answer\n",
    "\n",
    "        def cells(reuse=False):\n",
    "            return tf.nn.rnn_cell.LSTMCell(size_layer,initializer=tf.orthogonal_initializer(),reuse=reuse)\n",
    "\n",
    "        # Teacher forcing: decoder input is GO followed by the target minus its last column.\n",
    "        main = self.Y[:, :-1]\n",
    "        decoder_input = tf.concat([tf.fill([batch_size, 1], GO), main], 1)\n",
    "        decoder_cells = tf.nn.rnn_cell.MultiRNNCell([cells() for _ in range(num_layers)])\n",
    "\n",
    "        # Seed the decoder from the last *valid* source step of each row.\n",
    "        # `answer[:, -1]` is a padded step for every sequence shorter than the\n",
    "        # batch maximum; post_softmax_masking zeroes those steps, so the decoder\n",
    "        # used to start from the same input-independent vector for such rows.\n",
    "        last_step = tf.maximum(self.X_seq_len - 1, 0)\n",
    "        last_index = tf.stack([tf.range(batch_size), last_step], axis = 1)\n",
    "        init_state = tf.gather_nd(answer, last_index)\n",
    "        encoder_state = tuple(\n",
    "            tf.nn.rnn_cell.LSTMStateTuple(c=init_state, h=init_state)\n",
    "            for _ in range(num_layers)\n",
    "        )\n",
    "\n",
    "        vocab_proj = tf.layers.Dense(vocab_size)\n",
    "\n",
    "        # Training decoder (teacher forced).\n",
    "        helper = tf.contrib.seq2seq.TrainingHelper(\n",
    "            inputs = tf.nn.embedding_lookup(lookup_table, decoder_input),\n",
    "            sequence_length = tf.to_int32(self.Y_seq_len))\n",
    "\n",
    "        decoder = tf.contrib.seq2seq.BasicDecoder(cell = decoder_cells,\n",
    "                                                  helper = helper,\n",
    "                                                  initial_state = encoder_state,\n",
    "                                                  output_layer = vocab_proj)\n",
    "\n",
    "        decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder = decoder,\n",
    "                                                                maximum_iterations = tf.reduce_max(self.Y_seq_len))\n",
    "\n",
    "        # Inference decoder (greedy), sharing cells and projection with training.\n",
    "        helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(embedding = lookup_table,\n",
    "                                                          start_tokens = tf.tile(\n",
    "                                                              tf.constant([GO],\n",
    "                                                                          dtype=tf.int32),\n",
    "                                                              [batch_size]),\n",
    "                                                          end_token = EOS)\n",
    "        decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "            cell = decoder_cells,\n",
    "            helper = helper,\n",
    "            initial_state = encoder_state,\n",
    "            output_layer = vocab_proj)\n",
    "        predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "            decoder = decoder,\n",
    "            maximum_iterations = 2 * tf.reduce_max(self.X_seq_len))\n",
    "        self.training_logits = decoder_output.rnn_output\n",
    "        self.logits = decoder_output.sample_id\n",
    "        self.fast_result = predicting_decoder_output.sample_id\n",
    "\n",
    "        # Loss over the non-padded target positions only.\n",
    "        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)\n",
    "        self.cost = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits,\n",
    "                                                     targets = self.Y,\n",
    "                                                     weights = masks)\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.cost)\n",
    "\n",
    "        # Token-level accuracy on the same non-padded positions.\n",
    "        y_t = tf.cast(tf.argmax(self.training_logits,axis=2), tf.int32)\n",
    "        bool_masks = tf.cast(masks, tf.bool)  # boolean_mask is documented for bool masks\n",
    "        self.prediction = tf.boolean_mask(y_t, bool_masks)\n",
    "        mask_label = tf.boolean_mask(self.Y, bool_masks)\n",
    "        correct_pred = tf.equal(self.prediction, mask_label)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "size_layer = 512\n",
    "num_layers = 2\n",
    "embedded_size = 256\n",
    "learning_rate = 1e-3\n",
    "batch_size = 128\n",
    "epoch = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/client/session.py:1750: UserWarning: An interactive session is already active. This can cause out-of-memory errors in some cases. You must explicitly call `InteractiveSession.close()` to release resources held by the other session(s).\n",
      "  warnings.warn('An interactive session is already active. This can '\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(LSTMStateTuple(c=<tf.Tensor 'strided_slice_1:0' shape=(?, 512) dtype=float32>, h=<tf.Tensor 'strided_slice_1:0' shape=(?, 512) dtype=float32>), LSTMStateTuple(c=<tf.Tensor 'strided_slice_1:0' shape=(?, 512) dtype=float32>, h=<tf.Tensor 'strided_slice_1:0' shape=(?, 512) dtype=float32>))\n"
     ]
    }
   ],
   "source": [
    "# Re-running this cell used to leave the previous InteractiveSession open\n",
    "# (TensorFlow warns about it and it can exhaust GPU memory), so release any\n",
    "# existing session before the graph is reset.\n",
    "try:\n",
    "    sess.close()\n",
    "except NameError:\n",
    "    pass  # first run after a kernel restart: there is nothing to close\n",
    "\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Translator(size_layer, num_layers, embedded_size, learning_rate)\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Short alias for the Keras padding helper used to build batches below.\n",
    "pad_sequences = tf.keras.preprocessing.sequence.pad_sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([[19397, 19397, 27768, 27768, 17696, 17696, 27768, 25642, 25642,\n",
       "         27768, 28355, 28503, 14535, 17902, 17902, 14535, 17254, 24473,\n",
       "         24473,  7048, 15065, 17168, 17168, 21212, 21212, 22429, 22429,\n",
       "         22429, 25737,  3915,  3915, 11557, 11557, 23311, 10254,  6953,\n",
       "          6953, 10254,  3712,  3712, 30643, 22712, 22712, 22712,  4579,\n",
       "          4579, 31011, 31011, 31011, 31011,  6226,   822, 23311, 25129,\n",
       "         25129, 20665,  9644,  9644, 31653, 31653, 31653, 31653, 21142,\n",
       "         21142, 21142, 23120, 23095, 26751, 13780, 13780, 13780, 16678],\n",
       "        [ 2823, 21088, 14140, 14140, 26778,  1178,  1178, 28964, 28964,\n",
       "         28964, 11199, 11199, 24198, 10398, 27551, 10398, 18156, 18529,\n",
       "          8502,  8502, 18146,  2039, 21595, 21595,  2379,  2379,  2379,\n",
       "          9760,  9760,  9760, 24497, 17615, 17615,  4110,  8817,  8817,\n",
       "         24513, 24513, 10747, 27780, 27780, 25660, 25660, 25660, 27313,\n",
       "         27313,  7231, 23454,  9132,  9132,  9132, 25565, 13134, 16748,\n",
       "         16748, 27633, 27633, 27633, 20065, 27633,  3558,  3558, 13439,\n",
       "         13439, 13439, 29486, 12549, 28570, 28570, 11940, 11940, 10776],\n",
       "        [19397, 19397, 27768, 27768, 17696, 17696, 27768, 25642, 25642,\n",
       "         27768, 28355, 28503, 14535, 17902, 17902, 14535, 17254, 24473,\n",
       "         24473,  7048, 15065, 17168, 17168, 21212, 21212, 22429, 22429,\n",
       "         22429, 25737,  3915,  3915, 11557, 11557, 23311, 10254,  6953,\n",
       "          6953, 10254,  3712,  3712, 30643, 22712, 22712, 22712,  4579,\n",
       "          4579, 31011, 31011, 31011, 31011,  6226,   822, 23311, 25129,\n",
       "         25129, 20665,  9644,  9644, 31653, 31653, 31653, 31653, 21142,\n",
       "         21142, 21142, 23120, 23095, 26751, 13780, 13780, 13780, 16678],\n",
       "        [19397, 19397, 27768, 27768, 17696, 17696, 27768, 25642, 25642,\n",
       "         27768, 28355, 28503, 14535, 17902, 17902, 14535, 17254, 24473,\n",
       "         24473,  7048, 15065, 17168, 17168, 21212, 21212, 22429, 22429,\n",
       "         22429, 25737,  3915,  3915, 11557, 11557, 23311, 10254,  6953,\n",
       "          6953, 10254,  3712,  3712, 30643, 22712, 22712, 22712,  4579,\n",
       "          4579, 31011, 31011, 31011, 31011,  6226,   822, 23311, 25129,\n",
       "         25129, 20665,  9644,  9644, 31653, 31653, 31653, 31653, 21142,\n",
       "         21142, 21142, 23120, 23095, 26751, 13780, 13780, 13780, 16678],\n",
       "        [19397, 19397, 27768, 27768, 17696, 17696, 27768, 25642, 25642,\n",
       "         27768, 28355, 28503, 14535, 17902, 17902, 14535, 17254, 24473,\n",
       "         24473,  7048, 15065, 17168, 17168, 21212, 21212, 22429, 22429,\n",
       "         22429, 25737,  3915,  3915, 11557, 11557, 23311, 10254,  6953,\n",
       "          6953, 10254,  3712,  3712, 30643, 22712, 22712, 22712,  4579,\n",
       "          4579, 31011, 31011, 31011, 31011,  6226,   822, 23311, 25129,\n",
       "         25129, 20665,  9644,  9644, 31653, 31653, 31653, 31653, 21142,\n",
       "         21142, 21142, 23120, 23095, 26751, 13780, 13780, 13780, 16678],\n",
       "        [19397, 19397, 27768, 27768, 17696, 17696, 27768, 25642, 25642,\n",
       "         27768, 28355, 28503, 14535, 17902, 17902, 14535, 17254, 24473,\n",
       "         24473,  7048, 15065, 17168, 17168, 21212, 21212, 22429, 22429,\n",
       "         22429, 25737,  3915,  3915, 11557, 11557, 23311, 10254,  6953,\n",
       "          6953, 10254,  3712,  3712, 30643, 22712, 22712, 22712,  4579,\n",
       "          4579, 31011, 31011, 31011, 31011,  6226,   822, 23311, 25129,\n",
       "         25129, 20665,  9644,  9644, 31653, 31653, 31653, 31653, 21142,\n",
       "         21142, 21142, 23120, 23095, 26751, 13780, 13780, 13780, 16678],\n",
       "        [19397, 19397, 27768, 27768, 17696, 17696, 27768, 25642, 25642,\n",
       "         27768, 28355, 28503, 14535, 17902, 17902, 14535, 17254, 24473,\n",
       "         24473,  7048, 15065, 17168, 17168, 21212, 21212, 22429, 22429,\n",
       "         22429, 25737,  3915,  3915, 11557, 11557, 23311, 10254,  6953,\n",
       "          6953, 10254,  3712,  3712, 30643, 22712, 22712, 22712,  4579,\n",
       "          4579, 31011, 31011, 31011, 31011,  6226,   822, 23311, 25129,\n",
       "         25129, 20665,  9644,  9644, 31653, 31653, 31653, 31653, 21142,\n",
       "         21142, 21142, 23120, 23095, 26751, 13780, 13780, 13780, 16678],\n",
       "        [19397, 19397, 27768, 27768, 17696, 17696, 27768, 25642, 25642,\n",
       "         27768, 28355, 28503, 14535, 17902, 17902, 14535, 17254, 24473,\n",
       "         24473,  7048, 15065, 17168, 17168, 21212, 21212, 22429, 22429,\n",
       "         22429, 25737,  3915,  3915, 11557, 11557, 23311, 10254,  6953,\n",
       "          6953, 10254,  3712,  3712, 30643, 22712, 22712, 22712,  4579,\n",
       "          4579, 31011, 31011, 31011, 31011,  6226,   822, 23311, 25129,\n",
       "         25129, 20665,  9644,  9644, 31653, 31653, 31653, 31653, 21142,\n",
       "         21142, 21142, 23120, 23095, 26751, 13780, 13780, 13780, 16678],\n",
       "        [19397, 19397, 27768, 27768, 17696, 17696, 27768, 25642, 25642,\n",
       "         27768, 28355, 28503, 14535, 17902, 17902, 14535, 17254, 24473,\n",
       "         24473,  7048, 15065, 17168, 17168, 21212, 21212, 22429, 22429,\n",
       "         22429, 25737,  3915,  3915, 11557, 11557, 23311, 10254,  6953,\n",
       "          6953, 10254,  3712,  3712, 30643, 22712, 22712, 22712,  4579,\n",
       "          4579, 31011, 31011, 31011, 31011,  6226,   822, 23311, 25129,\n",
       "         25129, 20665,  9644,  9644, 31653, 31653, 31653, 31653, 21142,\n",
       "         21142, 21142, 23120, 23095, 26751, 13780, 13780, 13780, 16678],\n",
       "        [19397, 19397, 27768, 27768, 17696, 17696, 27768, 25642, 25642,\n",
       "         27768, 28355, 28503, 14535, 17902, 17902, 14535, 17254, 24473,\n",
       "         24473,  7048, 15065, 17168, 17168, 21212, 21212, 22429, 22429,\n",
       "         22429, 25737,  3915,  3915, 11557, 11557, 23311, 10254,  6953,\n",
       "          6953, 10254,  3712,  3712, 30643, 22712, 22712, 22712,  4579,\n",
       "          4579, 31011, 31011, 31011, 31011,  6226,   822, 23311, 25129,\n",
       "         25129, 20665,  9644,  9644, 31653, 31653, 31653, 31653, 21142,\n",
       "         21142, 21142, 23120, 23095, 26751, 13780, 13780, 13780, 16678]],\n",
       "       dtype=int32), 10.3734865, 0.0]"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke test: push the first ten training pairs through the untrained graph\n",
    "# and display the greedy decode, the loss and the accuracy.\n",
    "batch_x = pad_sequences(train_X[:10], padding='post')\n",
    "batch_y = pad_sequences(train_Y[:10], padding='post')\n",
    "smoke_fetches = [model.fast_result, model.cost, model.accuracy]\n",
    "\n",
    "sess.run(smoke_fetches, feed_dict = {model.X: batch_x, model.Y: batch_y})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:46<00:00,  3.85it/s, accuracy=0.136, cost=5.99]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.31it/s, accuracy=0.151, cost=5.5] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, training avg loss 6.897701, training avg acc 0.094511\n",
      "epoch 1, testing avg loss 5.886279, testing avg acc 0.144228\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:45<00:00,  3.85it/s, accuracy=0.176, cost=5.32]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.51it/s, accuracy=0.204, cost=4.79]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 2, training avg loss 5.496175, training avg acc 0.165414\n",
      "epoch 2, testing avg loss 5.236630, testing avg acc 0.178004\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:45<00:00,  3.86it/s, accuracy=0.206, cost=4.95]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.67it/s, accuracy=0.22, cost=4.51] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 3, training avg loss 5.038396, training avg acc 0.190332\n",
      "epoch 3, testing avg loss 4.943071, testing avg acc 0.197108\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:44<00:00,  3.87it/s, accuracy=0.225, cost=4.7] \n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.63it/s, accuracy=0.242, cost=4.34]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 4, training avg loss 4.765566, training avg acc 0.208819\n",
      "epoch 4, testing avg loss 4.766912, testing avg acc 0.211949\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:45<00:00,  3.85it/s, accuracy=0.239, cost=4.5] \n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.68it/s, accuracy=0.253, cost=4.22]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 5, training avg loss 4.574376, training avg acc 0.223101\n",
      "epoch 5, testing avg loss 4.658642, testing avg acc 0.221847\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:45<00:00,  3.85it/s, accuracy=0.246, cost=4.35]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.56it/s, accuracy=0.253, cost=4.14]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 6, training avg loss 4.432821, training avg acc 0.233963\n",
      "epoch 6, testing avg loss 4.589426, testing avg acc 0.228429\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:45<00:00,  3.85it/s, accuracy=0.253, cost=4.21]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.66it/s, accuracy=0.28, cost=4.09] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 7, training avg loss 4.322593, training avg acc 0.242401\n",
      "epoch 7, testing avg loss 4.546140, testing avg acc 0.233800\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:46<00:00,  3.85it/s, accuracy=0.255, cost=4.09]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.67it/s, accuracy=0.29, cost=4.05] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 8, training avg loss 4.231865, training avg acc 0.249445\n",
      "epoch 8, testing avg loss 4.521348, testing avg acc 0.236877\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:46<00:00,  3.85it/s, accuracy=0.265, cost=3.98]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.67it/s, accuracy=0.296, cost=4]   \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 9, training avg loss 4.154636, training avg acc 0.255596\n",
      "epoch 9, testing avg loss 4.501642, testing avg acc 0.239343\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:46<00:00,  3.85it/s, accuracy=0.275, cost=3.89]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.64it/s, accuracy=0.296, cost=3.98]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 10, training avg loss 4.086490, training avg acc 0.261151\n",
      "epoch 10, testing avg loss 4.492503, testing avg acc 0.241737\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:46<00:00,  3.85it/s, accuracy=0.284, cost=3.8] \n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.66it/s, accuracy=0.301, cost=3.98]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 11, training avg loss 4.025933, training avg acc 0.266182\n",
      "epoch 11, testing avg loss 4.490067, testing avg acc 0.243361\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:46<00:00,  3.85it/s, accuracy=0.285, cost=3.73]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.65it/s, accuracy=0.301, cost=3.97]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 12, training avg loss 3.971081, training avg acc 0.270903\n",
      "epoch 12, testing avg loss 4.493936, testing avg acc 0.244468\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:46<00:00,  3.85it/s, accuracy=0.296, cost=3.66]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.67it/s, accuracy=0.296, cost=3.97]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 13, training avg loss 3.920249, training avg acc 0.275454\n",
      "epoch 13, testing avg loss 4.502479, testing avg acc 0.245590\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:46<00:00,  3.85it/s, accuracy=0.303, cost=3.6] \n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.69it/s, accuracy=0.285, cost=3.97]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 14, training avg loss 3.873923, training avg acc 0.279623\n",
      "epoch 14, testing avg loss 4.512751, testing avg acc 0.245827\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:46<00:00,  3.85it/s, accuracy=0.312, cost=3.54]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.65it/s, accuracy=0.285, cost=3.97]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 15, training avg loss 3.831019, training avg acc 0.283554\n",
      "epoch 15, testing avg loss 4.526254, testing avg acc 0.245474\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:46<00:00,  3.85it/s, accuracy=0.322, cost=3.48]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.68it/s, accuracy=0.28, cost=3.96] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 16, training avg loss 3.789867, training avg acc 0.287467\n",
      "epoch 16, testing avg loss 4.540810, testing avg acc 0.245914\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [07:19<00:00,  3.56it/s, accuracy=0.33, cost=3.42] \n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.74it/s, accuracy=0.285, cost=3.96]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 17, training avg loss 3.750946, training avg acc 0.291335\n",
      "epoch 17, testing avg loss 4.559640, testing avg acc 0.246204\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [07:18<00:00,  3.56it/s, accuracy=0.336, cost=3.36]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.64it/s, accuracy=0.274, cost=3.97]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 18, training avg loss 3.714676, training avg acc 0.294952\n",
      "epoch 18, testing avg loss 4.579117, testing avg acc 0.245720\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:46<00:00,  3.85it/s, accuracy=0.341, cost=3.33]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.64it/s, accuracy=0.29, cost=3.95] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 19, training avg loss 3.680817, training avg acc 0.298401\n",
      "epoch 19, testing avg loss 4.608824, testing avg acc 0.246128\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:46<00:00,  3.85it/s, accuracy=0.349, cost=3.26]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.65it/s, accuracy=0.29, cost=3.93] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 20, training avg loss 3.650501, training avg acc 0.301656\n",
      "epoch 20, testing avg loss 4.604999, testing avg acc 0.246473\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import tqdm\n",
    "\n",
    "def run_epoch(inputs, targets, fetches):\n",
    "    \"\"\"Run one pass over (`inputs`, `targets`) in minibatches.\n",
    "\n",
    "    Args:\n",
    "        inputs: list of source id sequences.\n",
    "        targets: list of target id sequences, aligned with `inputs`.\n",
    "        fetches: list passed to `sess.run`; its first two entries must be the\n",
    "            accuracy and the cost tensors. Any further entries (such as the\n",
    "            optimizer op) are run and their results discarded.\n",
    "\n",
    "    Returns:\n",
    "        (losses, accuracies): two lists with one value per minibatch.\n",
    "    \"\"\"\n",
    "    losses, accuracies = [], []\n",
    "    pbar = tqdm.tqdm(\n",
    "        range(0, len(inputs), batch_size), desc = 'minibatch loop')\n",
    "    for start in pbar:\n",
    "        stop = min(start + batch_size, len(inputs))\n",
    "        feed = {model.X: pad_sequences(inputs[start : stop], padding='post'),\n",
    "                model.Y: pad_sequences(targets[start : stop], padding='post')}\n",
    "        accuracy, loss = sess.run(fetches, feed_dict = feed)[:2]\n",
    "        losses.append(loss)\n",
    "        accuracies.append(accuracy)\n",
    "        pbar.set_postfix(cost = loss, accuracy = accuracy)\n",
    "    return losses, accuracies\n",
    "\n",
    "\n",
    "for e in range(epoch):\n",
    "    # Training pass: including the optimizer op in the fetches applies the updates.\n",
    "    train_loss, train_acc = run_epoch(\n",
    "        train_X, train_Y, [model.accuracy, model.cost, model.optimizer])\n",
    "    # Evaluation pass: same metrics, no optimizer op, so weights are untouched.\n",
    "    test_loss, test_acc = run_epoch(\n",
    "        test_X, test_Y, [model.accuracy, model.cost])\n",
    "\n",
    "    print('epoch %d, training avg loss %f, training avg acc %f'%(e+1,\n",
    "                                                                 np.mean(train_loss),np.mean(train_acc)))\n",
    "    print('epoch %d, testing avg loss %f, testing avg acc %f'%(e+1,\n",
    "                                                              np.mean(test_loss),np.mean(test_acc)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensor2tensor.utils import bleu_hook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 40/40 [00:08<00:00,  4.62it/s]\n"
     ]
    }
   ],
   "source": [
    "def strip_special(ids):\n",
    "    \"\"\"Cut a decoded id sequence at its first EOS and drop reserved ids (<= 3).\n",
    "\n",
    "    The greedy decoder keeps emitting ids for a row after that row has produced\n",
    "    EOS (until every row in the batch is finished), so everything from the\n",
    "    first EOS onwards is discarded before filtering.\n",
    "    \"\"\"\n",
    "    ids = list(ids)\n",
    "    if EOS in ids:\n",
    "        ids = ids[:ids.index(EOS)]\n",
    "    return [token for token in ids if token > 3]\n",
    "\n",
    "\n",
    "results = []\n",
    "for start in tqdm.tqdm(range(0, len(test_X), batch_size)):\n",
    "    stop = min(start + batch_size, len(test_X))\n",
    "    batch_x = pad_sequences(test_X[start : stop], padding='post')\n",
    "    decoded = sess.run(model.fast_result, feed_dict = {model.X: batch_x})\n",
    "    results.extend(strip_special(row) for row in decoded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference translations with the reserved ids (<= 3) removed, using the same\n",
    "# threshold that is applied to the decoded model output.\n",
    "rights = [[token for token in reference if token > 3] for reference in test_Y]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
